"""
Bar plot of the averaged explained-variance statistics (SSH, cross-track
slope and along-track slope) for HRET and HYCOM during the SWOT SCIENCE orbit.

B. Yadidya
Feb 23, 2026
"""

import sys
sys.path.append('.')
from common_imports import *

# Global matplotlib style for this figure: small sans-serif text and thin
# axis / tick / grid lines. Font fallbacks are tried left to right.
_style = {
    # Fonts
    'font.family': 'sans-serif',
    'font.sans-serif': ['Myriad Pro', 'DejaVu Sans', 'Arial'],
    'font.size': 9,
    'axes.titlesize': 9,
    'figure.titlesize': 9,
    'legend.fontsize': 9,
    'axes.labelsize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    # Line weights
    'axes.linewidth': 0.5,
    'xtick.major.width': 0.5,
    'ytick.major.width': 0.5,
    'grid.linewidth': 0.5,
    'lines.linewidth': 1.0,
    # Figure resolution
    'figure.dpi': 500,
}
plt.rcParams.update(_style)

# --- X-axis groups (one bar per entry, left to right) ---
categories = [
    'HRET\n(5 major tides)',
    'HYCOM\n(5 major tides)',
    'HYCOM\n(8 major tides)',
    'HYCOM\n(Total\nInternal Tide)',
    # Optional extra group; to enable it, also append a value to every series
    # below (a 5th bar color is already provided in `bar_colors`).
    # 'HYCOM\n(Total \nInternal Tide \n& masked eddies)'
]
num_groups = len(categories)

# --- Averaged explained variance, one value per category (same order) ---
explained_ssh      = [20.46, 24.28, 25.23, 32.57]   # SSH, mm^2
explained_xt_slope = [22.35, 31.17, 32.53, 35.68]   # cross-track slope, 1e-3 (mm/km)^2
explained_at_slope = [25.24, 30.00, 31.37, 33.12]   # along-track slope, 1e-3 (mm/km)^2

# One color per category. The palette was previously defined twice (the first,
# 4-color list was immediately overwritten); only this list is kept. The 5th
# color is presumably reserved for the optional commented-out group above.
bar_colors = [
    "#989898",  # Gray
    "#1b9e77",  # Green
    "#7570b3",  # Purple
    "#d95f02",  # Orange
    "#e7298a",  # Magenta/Pink
]

# Guard against the series or palette drifting out of step with `categories`:
# matplotlib would otherwise fail with a shape error (series) or silently
# recycle colors (palette too short).
for _name, _series in (
    ("explained_ssh", explained_ssh),
    ("explained_xt_slope", explained_xt_slope),
    ("explained_at_slope", explained_at_slope),
):
    if len(_series) != num_groups:
        raise ValueError(
            f"{_name} has {len(_series)} values but there are {num_groups} categories"
        )
if len(bar_colors) < num_groups:
    raise ValueError(
        f"bar_colors has {len(bar_colors)} colors but there are {num_groups} categories"
    )

# --- Create figure: three stacked panels sharing one category x-axis ---
# Target print size is given in centimetres; matplotlib's figsize is in inches.
fig_width_cm, fig_height_max_cm = 9, 13
fig_width_in, fig_height_in = fig_width_cm / 2.54, fig_height_max_cm / 2.54

fig, axs = plt.subplots(
    nrows=3,
    ncols=1,
    figsize=(fig_width_in, fig_height_in),
    sharex=True,
    constrained_layout=True,
)

# --- Panel definitions: (values, panel title), plus matching y-axis units ---
data_list = [
    (explained_ssh,         "(A) Explained variance in SSH"),
    (explained_xt_slope,    "(B) Explained variance in cross-track slope"),
    (explained_at_slope,    "(C) Explained variance in along-track slope")
]
y_labels = [
    r'mm$^2$',
    r'$10^{-3}$ (mm/km)$^{2}$',
    r'$10^{-3}$ (mm/km)$^{2}$',
]

bar_width = 0.55
x = np.arange(num_groups)

# The first N_PHASE_LOCKED bars are grouped under a shaded band captioned
# 'Phase-locked Internal Tides'.
N_PHASE_LOCKED = 3


def _value_label(value, baseline, is_baseline):
    """Return the annotation text drawn above one bar.

    The baseline (HRET) bar shows only its value; every other bar also shows
    its percentage change relative to ``baseline``, rounded to whole percent.
    """
    if is_baseline:
        return f"{value:.2f}"
    pct_inc = (value - baseline) / baseline * 100
    return f"{value:.2f}\n({pct_inc:+.0f}%)"


for i, (ax, (y_vals, row_title), y_label) in enumerate(
    zip(axs, data_list, y_labels)
):
    bars = ax.bar(
        x, y_vals, width=bar_width, color=bar_colors,
        edgecolor='black', zorder=2
    )

    # Background shade from the left edge of the first bar to the right edge
    # of the last phase-locked bar; zorder=0 keeps it behind the bars.
    ax.axvspan(-bar_width/2, (N_PHASE_LOCKED - 1) + bar_width/2,
               color='gray', alpha=0.1, zorder=0)

    # Caption the shaded region on the top panel only, centred over the band.
    if i == 0:
        ax.text((N_PHASE_LOCKED - 1) / 2, 40, 'Phase-locked Internal Tides',
                ha='center', va='center', fontsize=9, style='italic')

    # Value labels above each bar; non-HRET bars also show % change vs HRET.
    hret_val = y_vals[0]
    for j, bar in enumerate(bars):
        yval = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width()/2, yval + 0.9,
            _value_label(yval, hret_val, is_baseline=(j == 0)),
            ha='center', va='bottom', fontsize=8, fontweight='medium'
        )

    ax.set_ylabel(y_label)
    ax.set_ylim(0, 45)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['left'].set_linewidth(1.3)
    ax.spines['bottom'].set_linewidth(1.3)
    ax.tick_params(axis='y', which='both', length=4, width=1.2, labelsize=8)
    # With sharex=True the x ticks set on the bottom panel are shared by every
    # panel (only their labels are hidden on the upper ones), so hide the tick
    # marks on all panels, not just the bottom one, for a consistent look.
    ax.tick_params(axis='x', length=0)
    ax.set_title(row_title, pad=8, fontweight='bold', loc='left')
    ax.grid(False)

# Category labels appear only under the bottom panel.
axs[-1].set_xticks(x)
axs[-1].set_xticklabels(categories, rotation=0, ha='center')
axs[-1].tick_params(axis='x', labelsize=8)

plt.savefig('Fig1_stats_summary_bar_plot.pdf', bbox_inches='tight')
